# pylint: disable=no-name-in-module

import copy

# import functools
import logging
from pathlib import Path

import hydra
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from diffusion_bandit import utils
from diffusion_bandit.linear_ts_plotting import plot_baseline
from diffusion_bandit.linear_ts_utils import initialize_parameters, update_posterior
from diffusion_bandit.neural_networks.shape_reward_nets import (
    get_ground_truth_reward_model,
)

# from typing import Callable, Tuple


# Module-wide logging setup: install a root handler that emits records at
# INFO and above, and expose a logger named after this module.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(__name__)


@hydra.main(
    version_base=None, config_path="configs", config_name="linear_thompson_sampling"
)
def main(config: DictConfig) -> None:
    """
    Hydra entry point: run linear Thompson sampling in both action modes.

    Loads the checkpoint ``<outputs_dir>/<names.score_model>.pth`` and builds the
    ground-truth reward model from the dataset configuration stored in it. Then,
    for each mode ("d_int" followed by "d_ext"), it performs
    ``config.thompson.num_runs`` runs. Each run draws fresh parameters via
    ``initialize_parameters``. The runs are then plotted with ``plot_baseline``,
    and the list of per-run result dicts is saved to
    ``<outputs_dir>/all_results_<mode>.pth``.
    """
    logger.info(OmegaConf.to_yaml(config))
    utils.seeding.seed_everything(config)

    outputs_dir = Path(config.outputs_dir)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only point
    # this at checkpoints you produced yourself.
    checkpoint = torch.load(
        outputs_dir / f"{config.names.score_model}.pth", weights_only=False
    )
    device = torch.device(config.sampler.device)

    dataset_cfg = checkpoint["dataset_config"]["dataset"]
    # "d_int" is read (so a missing key still fails here) but not used below.
    d_ext, _d_int, radius, surface = (
        dataset_cfg[key] for key in ("d_ext", "d_int", "radius", "surface")
    )
    projector, x_data = checkpoint["projector"], checkpoint["x_data"]

    reward_model = get_ground_truth_reward_model(
        d_ext=d_ext,
        projector=projector,
        radius=radius,
        surface=surface,
        name=config.reward_model.name,
    )

    for mode in ("d_int", "d_ext"):
        runs = []
        for run_idx in range(config.thompson.num_runs):
            logger.info(f"Starting run {run_idx + 1}/{config.thompson.num_runs}")
            prior_mean, prior_cov, theta_true = initialize_parameters(
                d_ext, config.thompson.prior_var, device
            )
            # Each run gets its own copy of the reward model carrying the
            # freshly drawn ground-truth parameter.
            gt_model = copy.deepcopy(reward_model)
            gt_model.layer.weight.data = theta_true.clone().detach()

            run_result = thompson_sampling(
                config,
                gt_model,
                prior_mean.clone().detach(),
                prior_cov.clone().detach(),
                device,
                radius,
                projector,
                x_data,
                mode,
            )
            runs.append(run_result)

        plot_baseline(runs, projector, mode)

        save_path = outputs_dir / f"all_results_{mode}.pth"
        torch.save(runs, save_path)
        logger.info(f"All results saved to {save_path}")


# Action-selection modes understood by ``thompson_sampling``.
_THOMPSON_MODES = ("d_int", "d_ext")


def _sphere_maximizer(
    theta: torch.Tensor, radius: float, projector: torch.Tensor, mode: str
) -> torch.Tensor:
    """
    Return the action that is greedy for the parameter sample ``theta``.

    * ``"d_ext"``: ``radius * theta / ||theta||``. For a linear score
      ``theta . x`` this maximizes over the full sphere of radius ``radius``.
    * ``"d_int"``: ``radius * theta @ P @ P.T / ||theta @ P||`` with
      ``P = projector``. If ``P`` has orthonormal columns (assumption, TODO
      confirm), this maximizes ``theta . x`` over the radius-``radius`` sphere
      inside ``span(P)``.

    Raises:
        ValueError: If ``mode`` is not one of ``_THOMPSON_MODES``.
    """
    if mode == "d_int":
        # Operator order is kept identical to the original inline expression
        # so the floating-point result is unchanged.
        return (
            radius
            * theta
            @ projector
            @ projector.t()
            / torch.linalg.norm(theta @ projector)
        )
    if mode == "d_ext":
        return radius * theta / torch.linalg.norm(theta)
    raise ValueError(f"Unknown mode {mode!r}; expected one of {_THOMPSON_MODES}.")


def thompson_sampling(
    config: DictConfig,
    reward_model: torch.nn.Module,
    mean: torch.Tensor,
    cov: torch.Tensor,
    device: torch.device,
    radius: float,
    projector: torch.Tensor,
    x_data: torch.Tensor,
    mode: str,
):
    """
    Run linear Thompson sampling against a ground-truth reward model.

    Each iteration does the following:
    1. Draw ``theta`` from ``N(mean, cov)``.
    2. Play the action that is greedy for that draw (see ``_sphere_maximizer``).
    3. Observe the ground-truth reward plus Gaussian noise.
    4. Update ``(mean, cov)`` with ``update_posterior``.

    Args:
        config (DictConfig): Configuration. Reads ``config.thompson.noise_var``
            and ``config.thompson.thompson_iterates``.
        reward_model (torch.nn.Module): Ground-truth reward model. Its
            ``layer.weight`` holds the true parameter.
        mean (torch.Tensor): Initial Gaussian mean over the parameter. It is
            replaced by the posterior mean after every update.
        cov (torch.Tensor): Initial Gaussian covariance over the parameter. It
            is replaced by the posterior covariance after every update.
        device (torch.device): Device forwarded to ``update_posterior``.
        radius (float): Scale of every action played.
        projector (torch.Tensor): Matrix used for the ``"d_int"`` subspace
            restriction and for ``max_obtainable``.
        x_data (torch.Tensor): Dataset points. Only used to log the largest
            ground-truth reward present in the data, as a sanity check.
        mode (str): ``"d_int"`` (act within the projector's span) or
            ``"d_ext"`` (act on the full sphere).

    Returns:
        Dict[str, Any]: Per-iteration lists of numpy arrays, plus the true
        parameter and a reference reward.
        - Per-iteration lists: ``rewards_gt``, ``rewards_gt_noisy``,
          ``rewards_iterate`` (reward predicted by the sampled parameter at its
          chosen action), ``posterior_mean``, ``posterior_cov`` and
          ``theta_iterate``.
        - ``theta_gt``: the true parameter, as a numpy array.
        - ``max_obtainable``: a tensor equal to
          ``radius * ||theta_gt @ projector||``.

    Raises:
        ValueError: If ``mode`` is not ``"d_int"`` or ``"d_ext"``.
    """
    # Validate up front: an unknown mode used to surface only as an unbound
    # `maximizer` inside the loop.
    if mode not in _THOMPSON_MODES:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {_THOMPSON_MODES}.")

    noise_var = config.thompson.noise_var
    thompson_iterates = config.thompson.thompson_iterates
    noise_std = noise_var**0.5

    theta_gt = reward_model.layer.weight.data.clone().detach()
    # For a linear reward, this is the score of the best radius-`radius` action
    # in span(projector), presumably the best achievable reward. TODO confirm.
    max_obtainable = radius * torch.linalg.norm(theta_gt @ projector)
    with torch.no_grad():
        data_max_reward = torch.max(reward_model(x_data))
    logger.info("Max ground-truth reward over x_data: %s", data_max_reward)
    logger.info("Max obtainable reward: %s", max_obtainable)

    results = {
        "rewards_gt": [],
        "rewards_gt_noisy": [],
        "rewards_iterate": [],
        "posterior_mean": [],
        "posterior_cov": [],
        "theta_iterate": [],
        "theta_gt": reward_model.layer.weight.data.clone().cpu().numpy(),
        "max_obtainable": max_obtainable,
    }

    # Copy of the reward model whose weights are overwritten with each
    # posterior sample, so we can record what that sample predicts.
    reward_model_iterate = copy.deepcopy(reward_model)
    rewards_gt = None  # stays None when thompson_iterates == 0

    for _ in tqdm(range(thompson_iterates)):
        theta_iterate = torch.distributions.MultivariateNormal(mean, cov).sample()
        reward_model_iterate.layer.weight.data = theta_iterate.clone().detach()

        maximizer = _sphere_maximizer(theta_iterate, radius, projector, mode)
        action = maximizer.reshape(1, -1)

        with torch.no_grad():
            rewards_gt = reward_model(action)
            rewards_iterate = reward_model_iterate(action)

        noisy_rewards = rewards_gt + torch.randn_like(rewards_gt) * noise_std

        # Update posterior
        mean, cov = update_posterior(
            mean=mean,
            cov=cov,
            x_new=maximizer.unsqueeze(0),
            y_new=noisy_rewards,
            noise_var=noise_var,
            device=device,
        )

        results["rewards_gt"].append(np.atleast_1d(rewards_gt.clone().cpu().numpy()))
        results["rewards_gt_noisy"].append(
            np.atleast_1d(noisy_rewards.clone().cpu().numpy())
        )
        results["rewards_iterate"].append(
            np.atleast_1d(rewards_iterate.clone().cpu().numpy())
        )
        results["posterior_mean"].append(mean.clone().cpu().numpy())
        results["posterior_cov"].append(cov.clone().cpu().numpy())
        results["theta_iterate"].append(theta_iterate.clone().cpu().numpy())

    if rewards_gt is not None:
        logger.info("Final ground-truth reward: %s", rewards_gt)
    return results


if __name__ == "__main__":
    # `config` is supplied by the @hydra.main decorator, so no argument is passed.
    main()  # pylint: disable=no-value-for-parameter
